
#

import os
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import math


class FC_Net(nn.Module):
    """A single fully connected (linear) layer wrapped as its own module.

    The layer is registered as ``self.fc1``, so the module's state dict uses
    the keys ``fc1.weight`` and ``fc1.bias``.

    Args:
        in_size: number of input features.
        out_size: number of output features.
    """

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        # Record the layer dimensions as plain attributes.
        self.in_size, self.out_size = in_size, out_size
        self.fc1 = nn.Linear(in_size, out_size)

    def forward(self, input: Tensor):
        """Apply the linear layer to ``input`` and return the result."""
        return self.fc1(input)

# Conversion script: copy the weights of an extracted FC-layer checkpoint into
# an FC_Net and save the result as a new state dict.


def copy_linear_params(fc: nn.Linear, state: dict) -> None:
    """Load a weight/bias pair from ``state`` into the linear layer ``fc``.

    The first entry of ``state`` (in iteration order) is taken as the weight
    and the second as the bias; any further entries are ignored. This keeps
    the key-order convention of the original conversion code.

    Unlike assigning fresh ``Parameter`` objects, ``load_state_dict`` keeps
    the layer's existing parameters, checks that shapes match, and copies the
    values into them (so they end up in the layer's dtype/device).

    Args:
        fc: destination layer.
        state: mapping of name -> tensor; presumably a state dict produced by
            ``torch.load`` (verify for new checkpoints).

    Raises:
        ValueError: if ``state`` has fewer than two entries.
        RuntimeError: if a tensor's shape does not match the layer's
            corresponding parameter (raised by ``load_state_dict``).
    """
    keys = list(state)
    if len(keys) < 2:
        raise ValueError(
            f"expected at least 2 entries (weight, bias), got {len(keys)}: {keys}"
        )
    fc.load_state_dict({"weight": state[keys[0]], "bias": state[keys[1]]})


if __name__ == '__main__':
    in_size = 512  # input feature size
    out_size = 101  # output feature size

    net_fc = FC_Net(in_size=in_size, out_size=out_size)  # build the network

    # Directory holding both the source checkpoint and the converted output.
    model_dir = r"D:\Works\Hisi_codes\SLR-master\res50lstm_models"

    # Load the checkpoint to convert. map_location="cpu" lets tensors that
    # were saved from a GPU load on a machine without CUDA.
    weight_path = os.path.join(model_dir, "res50lstm_epoch026-fc.pth")
    print("Load params from : ", weight_path)
    model = torch.load(weight_path, map_location="cpu")

    for k in model:  # list the checkpoint's key names
        print(k)

    # Copy the checkpoint's weight/bias into the fc layer (shape-checked).
    copy_linear_params(net_fc.fc1, model)

    # Inspect the network's full set of parameters.
    print("net_fc.state_dict():\n", net_fc.state_dict())

    # Save the converted weights.
    print("save....")
    torch.save(net_fc.state_dict(), os.path.join(model_dir, "net_last_fc.pth"))

